{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 测试 ONNX Relax"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "创建缓存目录："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "temp_dir = Path(\".temp\")\n",
    "temp_dir.mkdir(exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 构建 ONNX 模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "class M(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.conv = torch.nn.Conv2d(3, 16, 3, bias=False)\n",
    "        self.conv2 = torch.nn.Conv2d(16, 32, 1, bias=False)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # x = self.conv(x)\n",
    "        x = F.interpolate(x, size=None, scale_factor=(0.5, 0.5), mode=\"nearest\",)\n",
    "        return x\n",
    "\n",
    "\n",
    "torch_model = M()\n",
    "input_tensor = torch.randn(1, 3, 10, 10)\n",
    "torch.onnx.export(\n",
    "    torch_model, \n",
    "    (input_tensor,), \n",
    "    temp_dir/\"test.onnx\", \n",
    "    input_names=[\"x\"],\n",
    "    opset_version=11,\n",
    ")\n",
    "torch.onnx.export(\n",
    "    torch_model, \n",
    "    (input_tensor,), \n",
    "    temp_dir/\"test19.onnx\", \n",
    "    input_names=[\"x\"],\n",
    "    opset_version=19,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 转换 ONNX 模型为 Relax 模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Error converting operator Resize, with inputs: [x, metadata[\"relax.expr.Constant\"][0]\n",
      "# Metadata omitted. Use show_meta=True in script() method to show it., metadata[\"relax.expr.Constant\"][0]\n",
      "# Metadata omitted. Use show_meta=True in script() method to show it.]\n"
     ]
    },
    {
     "ename": "TVMError",
     "evalue": "Traceback (most recent call last):\n  File \"/media/pc/data/lxw/ai/tvm/include/tvm/runtime/packed_func.h\", line 924\nTVMError: In function relax.op.image.resize2d(0: RelaxExpr, 1: RelaxExpr, 2: Array<FloatImm>, 3: runtime.String, 4: runtime.String, 5: runtime.String, 6: runtime.String, 7: double, 8: int, 9: double, 10: DataType) -> RelaxExpr: error while converting argument 2: [17:25:38] /media/pc/data/lxw/ai/tvm/include/tvm/runtime/packed_func.h:2274: InternalError: Check failed: (!checked_type.defined()) is false: Expected Array[runtime.Object], but got relax.expr.Call\n",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mTVMError\u001b[0m                                  Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[5], line 4\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtvm\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mrelax\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mfrontend\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01monnx\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m from_onnx\n\u001b[1;32m      3\u001b[0m model \u001b[38;5;241m=\u001b[39m onnx\u001b[38;5;241m.\u001b[39mload(temp_dir\u001b[38;5;241m/\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtest.onnx\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m----> 4\u001b[0m tvm_model \u001b[38;5;241m=\u001b[39m \u001b[43mfrom_onnx\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m  \u001b[49m\u001b[43mkeep_params_in_input\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mopset\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m20\u001b[39;49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/media/pc/data/lxw/ai/tvm/python/tvm/relax/frontend/onnx/onnx_frontend.py:3690\u001b[0m, in \u001b[0;36mfrom_onnx\u001b[0;34m(model, shape_dict, dtype_dict, opset, keep_params_in_input, sanitize_input_names)\u001b[0m\n\u001b[1;32m   3683\u001b[0m     warnings\u001b[38;5;241m.\u001b[39mwarn(\n\u001b[1;32m   3684\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m   3685\u001b[0m         \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mYou are overwritting original opset ver = \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mopset_in_model\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m by lower ver = \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mopset\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m. \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m   3686\u001b[0m         \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mThat might cause model conversion errors.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m   3687\u001b[0m     )\n\u001b[1;32m   3689\u001b[0m \u001b[38;5;66;03m# Use the graph proto as a scope so that ops can access other nodes if needed.\u001b[39;00m\n\u001b[0;32m-> 3690\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mg\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_onnx\u001b[49m\u001b[43m(\u001b[49m\u001b[43mgraph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mopset\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/media/pc/data/lxw/ai/tvm/python/tvm/relax/frontend/onnx/onnx_frontend.py:3321\u001b[0m, in \u001b[0;36mONNXGraphImporter.from_onnx\u001b[0;34m(self, graph, opset)\u001b[0m\n\u001b[1;32m   3319\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_parse_graph_input(graph)\n\u001b[1;32m   3320\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_check_for_unsupported_ops(graph)\n\u001b[0;32m-> 3321\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_construct_nodes\u001b[49m\u001b[43m(\u001b[49m\u001b[43mgraph\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   3323\u001b[0m \u001b[38;5;66;03m# now return the outputs\u001b[39;00m\n\u001b[1;32m   3324\u001b[0m outputs \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_nodes[\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_parse_value_proto(i)] \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m graph\u001b[38;5;241m.\u001b[39moutput]\n",
      "File \u001b[0;32m/media/pc/data/lxw/ai/tvm/python/tvm/relax/frontend/onnx/onnx_frontend.py:3501\u001b[0m, in \u001b[0;36mONNXGraphImporter._construct_nodes\u001b[0;34m(self, graph)\u001b[0m\n\u001b[1;32m   3499\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m TVMError \u001b[38;5;28;01mas\u001b[39;00m err:\n\u001b[1;32m   3500\u001b[0m     \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mError converting operator \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mop_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, with inputs: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00minputs\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m-> 3501\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m err\n\u001b[1;32m   3503\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m op_name \u001b[38;5;129;01min\u001b[39;00m return_tuple_ops:\n\u001b[1;32m   3504\u001b[0m     outputs_num \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n",
      "File \u001b[0;32m/media/pc/data/lxw/ai/tvm/python/tvm/relax/frontend/onnx/onnx_frontend.py:3496\u001b[0m, in \u001b[0;36mONNXGraphImporter._construct_nodes\u001b[0;34m(self, graph)\u001b[0m\n\u001b[1;32m   3494\u001b[0m         \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mNode \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnode\u001b[38;5;241m.\u001b[39mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m cannot handle ShapeExpr inputs.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m   3495\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 3496\u001b[0m     op \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_convert_operator\u001b[49m\u001b[43m(\u001b[49m\u001b[43mop_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mattr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mopset\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   3497\u001b[0m     \u001b[38;5;66;03m# Create struct information for the new operator.\u001b[39;00m\n\u001b[1;32m   3498\u001b[0m     op \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbb\u001b[38;5;241m.\u001b[39mnormalize(op)\n",
      "File \u001b[0;32m/media/pc/data/lxw/ai/tvm/python/tvm/relax/frontend/onnx/onnx_frontend.py:3596\u001b[0m, in \u001b[0;36mONNXGraphImporter._convert_operator\u001b[0;34m(self, op_name, inputs, attrs, opset)\u001b[0m\n\u001b[1;32m   3594\u001b[0m     convert_class \u001b[38;5;241m=\u001b[39m convert_map[op_name]\n\u001b[1;32m   3595\u001b[0m     op_function \u001b[38;5;241m=\u001b[39m convert_class\u001b[38;5;241m.\u001b[39mget_converter(opset)\n\u001b[0;32m-> 3596\u001b[0m     sym \u001b[38;5;241m=\u001b[39m \u001b[43mop_function\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbb\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mattrs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_nodes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_params\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   3597\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m   3598\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mNotImplementedError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mOperator \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m not implemented.\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mformat(op_name))\n",
      "File \u001b[0;32m/media/pc/data/lxw/ai/tvm/python/tvm/relax/frontend/onnx/onnx_frontend.py:2146\u001b[0m, in \u001b[0;36mResize._impl_v18\u001b[0;34m(cls, bb, inputs, attr, params)\u001b[0m\n\u001b[1;32m   2141\u001b[0m     \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(\n\u001b[1;32m   2142\u001b[0m         sizes, relax\u001b[38;5;241m.\u001b[39mConstant\n\u001b[1;32m   2143\u001b[0m     ), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mOnly constant output size currently supported.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m   2144\u001b[0m     sizes \u001b[38;5;241m=\u001b[39m sizes\u001b[38;5;241m.\u001b[39mdata\u001b[38;5;241m.\u001b[39mnumpy()\u001b[38;5;241m.\u001b[39mastype(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint64\u001b[39m\u001b[38;5;124m\"\u001b[39m)\u001b[38;5;241m.\u001b[39mtolist()[\u001b[38;5;241m2\u001b[39m:]\n\u001b[0;32m-> 2146\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mrelax\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mimage\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mresize2d\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   2147\u001b[0m \u001b[43m    \u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2148\u001b[0m \u001b[43m    \u001b[49m\u001b[43msize\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrelax\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mShapeExpr\u001b[49m\u001b[43m(\u001b[49m\u001b[43msizes\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2149\u001b[0m \u001b[43m    \u001b[49m\u001b[43mroi\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mroi\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2150\u001b[0m \u001b[43m    \u001b[49m\u001b[43mlayout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mNCHW\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2151\u001b[0m \u001b[43m    \u001b[49m\u001b[43mmethod\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2152\u001b[0m \u001b[43m    \u001b[49m\u001b[43mcoordinate_transformation_mode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcoord_mode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2153\u001b[0m \u001b[43m    \u001b[49m\u001b[43mrounding_method\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrounding_method\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2154\u001b[0m \u001b[43m    \u001b[49m\u001b[43mcubic_alpha\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcubic_coeff_a\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2155\u001b[0m \u001b[43m    \u001b[49m\u001b[43mcubic_exclude\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mexclude_outside\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2156\u001b[0m \u001b[43m    \u001b[49m\u001b[43mextrapolation_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mextrapolation_value\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2157\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/media/pc/data/lxw/ai/tvm/python/tvm/relax/op/image/image.py:116\u001b[0m, in \u001b[0;36mresize2d\u001b[0;34m(data, size, roi, layout, method, coordinate_transformation_mode, rounding_method, cubic_alpha, cubic_exclude, extrapolation_value, out_dtype)\u001b[0m\n\u001b[1;32m    113\u001b[0m     \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    114\u001b[0m         size \u001b[38;5;241m=\u001b[39m ShapeExpr(size)\n\u001b[0;32m--> 116\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_ffi_api\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mresize2d\u001b[49m\u001b[43m(\u001b[49m\u001b[43m  \u001b[49m\u001b[38;5;66;43;03m# type: ignore\u001b[39;49;00m\n\u001b[1;32m    117\u001b[0m \u001b[43m    \u001b[49m\u001b[43mdata\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    118\u001b[0m \u001b[43m    \u001b[49m\u001b[43msize\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    119\u001b[0m \u001b[43m    \u001b[49m\u001b[43mroi\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    120\u001b[0m \u001b[43m    \u001b[49m\u001b[43mlayout\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    121\u001b[0m \u001b[43m    \u001b[49m\u001b[43mmethod\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    122\u001b[0m \u001b[43m    \u001b[49m\u001b[43mcoordinate_transformation_mode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    123\u001b[0m \u001b[43m    \u001b[49m\u001b[43mrounding_method\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    124\u001b[0m \u001b[43m    \u001b[49m\u001b[43mcubic_alpha\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    125\u001b[0m \u001b[43m    \u001b[49m\u001b[43mcubic_exclude\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    126\u001b[0m \u001b[43m    \u001b[49m\u001b[43mextrapolation_value\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    127\u001b[0m \u001b[43m    \u001b[49m\u001b[43mout_dtype\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    128\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/media/pc/data/lxw/ai/tvm/python/tvm/_ffi/_cython/packed_func.pxi:339\u001b[0m, in \u001b[0;36mtvm._ffi._cy3.core.PackedFuncBase.__call__\u001b[0;34m()\u001b[0m\n",
      "File \u001b[0;32m/media/pc/data/lxw/ai/tvm/python/tvm/_ffi/_cython/packed_func.pxi:284\u001b[0m, in \u001b[0;36mtvm._ffi._cy3.core.FuncCall\u001b[0;34m()\u001b[0m\n",
      "File \u001b[0;32m/media/pc/data/lxw/ai/tvm/python/tvm/_ffi/_cython/base.pxi:185\u001b[0m, in \u001b[0;36mtvm._ffi._cy3.core.CHECK_CALL\u001b[0;34m()\u001b[0m\n",
      "File \u001b[0;32m/media/pc/data/lxw/ai/tvm/python/tvm/_ffi/base.py:468\u001b[0m, in \u001b[0;36mraise_last_ffi_error\u001b[0;34m()\u001b[0m\n\u001b[1;32m    462\u001b[0m \u001b[38;5;66;03m# The exception PyObject may contain a large amount of state,\u001b[39;00m\n\u001b[1;32m    463\u001b[0m \u001b[38;5;66;03m# including all stack frames that may be inspected in a later\u001b[39;00m\n\u001b[1;32m    464\u001b[0m \u001b[38;5;66;03m# PDB post-mortem.  Therefore, we must make sure to remove the\u001b[39;00m\n\u001b[1;32m    465\u001b[0m \u001b[38;5;66;03m# underlying PyObject* from the C++ side after we retrieve it.\u001b[39;00m\n\u001b[1;32m    466\u001b[0m _LIB\u001b[38;5;241m.\u001b[39mTVMDropLastPythonError()\n\u001b[0;32m--> 468\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m py_err\n",
      "\u001b[0;31mTVMError\u001b[0m: Traceback (most recent call last):\n  File \"/media/pc/data/lxw/ai/tvm/include/tvm/runtime/packed_func.h\", line 924\nTVMError: In function relax.op.image.resize2d(0: RelaxExpr, 1: RelaxExpr, 2: Array<FloatImm>, 3: runtime.String, 4: runtime.String, 5: runtime.String, 6: runtime.String, 7: double, 8: int, 9: double, 10: DataType) -> RelaxExpr: error while converting argument 2: [17:25:38] /media/pc/data/lxw/ai/tvm/include/tvm/runtime/packed_func.h:2274: InternalError: Check failed: (!checked_type.defined()) is false: Expected Array[runtime.Object], but got relax.expr.Call\n"
     ]
    }
   ],
   "source": [
    "import onnx\n",
    "from tvm.relax.frontend.onnx import from_onnx\n",
    "model = onnx.load(temp_dir/\"test.onnx\")\n",
    "tvm_model = from_onnx(model,  keep_params_in_input=True, opset=20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import onnx\n",
    "from tvm.relax.frontend.onnx import from_onnx\n",
    "model = onnx.load(temp_dir/\"test19.onnx\")\n",
    "tvm_model = from_onnx(model,  keep_params_in_input=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[0;31mSignature:\u001b[0m\n",
      "\u001b[0mfrom_onnx\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mmodel\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0monnx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0monnx_ml_pb2\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mGraphProto\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mshape_dict\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mDict\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mList\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mdtype_dict\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mDict\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mNoneType\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'float32'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mopset\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mkeep_params_in_input\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbool\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0msanitize_input_names\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbool\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mtvm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mir\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mIRModule\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mDocstring:\u001b[0m\n",
      "Convert a ONNX model into an equivalent Relax Function.\n",
      "ONNX graphs are represented as Python Protobuf objects.\n",
      "\n",
      "The current implementation assumes that the input model is after ONNX v1.1.0.\n",
      "\n",
      "Parameters\n",
      "----------\n",
      "model : protobuf object\n",
      "    ONNX ModelProto after ONNX v1.1.0\n",
      "shape_dict : dict of str to tuple, optional\n",
      "    The input shape to the graph\n",
      "dtype_dict : str or dict of str to str, optional\n",
      "    The input types to the graph\n",
      "opset : int, optional\n",
      "    Override to autodetected opset.\n",
      "    This can be helpful for some testing.\n",
      "keep_params_in_input : bool\n",
      "    If True, parameters will be treated as input variables. If false,\n",
      "    parameters are treated as constant and folded directly into the graph.\n",
      "sanitize_input_names : bool, optional\n",
      "    Whether to sanitize the input names to ensure they are valid Relax identifiers.\n",
      "\n",
      "Returns\n",
      "-------\n",
      "mod : tvm.IRModule\n",
      "    The relax module for compilation\n",
      "\u001b[0;31mFile:\u001b[0m      /media/pc/data/lxw/ai/tvm/python/tvm/relax/frontend/onnx/onnx_frontend.py\n",
      "\u001b[0;31mType:\u001b[0m      function"
     ]
    }
   ],
   "source": [
    "from_onnx?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tvm_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from io import StringIO\n",
    "from contextlib import redirect_stdout, redirect_stderr\n",
    "import tempfile\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import onnx\n",
    "from tvm.relax.frontend.onnx import from_onnx\n",
    "\n",
    "def test_resize():\n",
    "    class Resize(torch.nn.Module):\n",
    "        def forward(self, x):\n",
    "            x = F.interpolate(x, size=None, scale_factor=(0.5, 0.5), mode=\"nearest\",)\n",
    "            return x\n",
    "\n",
    "    torch_model = Resize()\n",
    "    input_tensor = torch.randn(1, 3, 10, 10)\n",
    "    with tempfile.TemporaryDirectory() as temp_dir:\n",
    "        onnx_path = f\"{temp_dir}/test.onnx\"\n",
    "        torch.onnx.export(\n",
    "            torch_model, \n",
    "            (input_tensor,), \n",
    "            onnx_path, \n",
    "            input_names=[\"x\"],\n",
    "            opset_version=11,\n",
    "        )\n",
    "        model = onnx.load(onnx_path)\n",
    "        # need fix\n",
    "        try:\n",
    "            with redirect_stdout(StringIO()) as sio:\n",
    "                tvm_model = from_onnx(model, keep_params_in_input=True)\n",
    "        except Exception as e:\n",
    "            print(f\"Exception: {e}\")\n",
    "            assert (\n",
    "                sio.getvalue() == \n",
    "                'Error converting operator Resize, with inputs: [x, metadata[\"relax.expr.Constant\"][0]\\n# Metadata omitted. '\n",
    "                'Use show_meta=True in script() method to show it., metadata[\"relax.expr.Constant\"][0]\\n# Metadata omitted. '\n",
    "                'Use show_meta=True in script() method to show it.]\\n'\n",
    "            )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from io import StringIO\n",
    "from contextlib import redirect_stdout\n",
    "import numpy as np\n",
    "from onnx import helper, TensorProto\n",
    "from onnxscript import script\n",
    "from onnxscript import FLOAT\n",
    "from onnxscript import opset11 as op\n",
    "from tvm.relax.frontend.onnx import from_onnx\n",
    "\n",
    "def test_resize():\n",
    "    @script()\n",
    "    def Resize(X: FLOAT[1, 3, 20, 20]):\n",
    "        scales = op.Constant(value=helper.make_tensor(\"scales\", TensorProto.FLOAT, (4,), [1, 1, 0.5, 0.5]))\n",
    "        roi = op.Constant(value=helper.make_tensor(\"roi\", TensorProto.FLOAT, (), [10]))\n",
    "        return op.Resize(X, roi=roi, scales=scales,)\n",
    "\n",
    "    onnx_result = Resize(X=np.random.randn(1, 3, 20, 20).astype(\"float32\"))\n",
    "    model = Resize.to_model_proto() # returns an onnx.ModelProto\n",
    "    # need fix\n",
    "    try:\n",
    "        with redirect_stdout(StringIO()) as sio:\n",
    "            tvm_model = from_onnx(model, keep_params_in_input=True)\n",
    "    except Exception as e:\n",
    "        print(f\"Exception: {e}\")\n",
    "        assert (\n",
    "            sio.getvalue() == \n",
    "            'Error converting operator Resize, with inputs: [X, R.const(10.0, \"float32\"), '\n",
    "            'metadata[\"relax.expr.Constant\"][0]\\n# Metadata omitted. '\n",
    "            'Use show_meta=True in script() method to show it.]\\n'\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_resize()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ai",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
